load("@rules_python//python:packaging.bzl", "py_wheel")

# CUDA variants ("arch" suffixes) for which a GPU PJRT plugin wheel is built.
cuda_wheels = ["cuda12", "cuda13"]

# The CPU variant. Kept as a one-element list so it can be concatenated with
# `cuda_wheels` and driven by the same list-comprehension pattern.
cpu_wheel = ["cpu"]

# Aliases and filegroups leave files where they already are in bazel-out, so a
# genrule copies the shared __init__.py into each variant's
# xla_plugins/xla_<arch>_pjrt/ directory, matching the layout the wheels package.
[
    genrule(
        name = "init_file_{}".format(arch),
        srcs = ["__init__.py"],
        outs = ["xla_plugins/xla_{}_pjrt/__init__.py".format(arch)],
        # With exactly one src and one out, $< is the input and $@ the output.
        cmd = "cp $< $@",
    )
    for arch in cuda_wheels + cpu_wheel
]

# GPU-specific files: one PJRT plugin shared object per CUDA variant, named so it
# lands at xla_plugins/xla_<arch>_pjrt/xla_gpu_pjrt.so for the matching wheel.
# NOTE(review): every CUDA variant links the same //xla/pjrt/c:pjrt_c_api_gpu
# target, so within a single build these binaries are presumably identical;
# confirm how the CUDA version is meant to differ per wheel.
[
    cc_binary(
        name = "xla_plugins/xla_" + arch + "_pjrt/xla_gpu_pjrt.so",
        # The version script is a linker input, not a C++ dependency. Declaring
        # it here makes it available to the $(location) expansion in linkopts.
        additional_linker_inputs = [":pjrt_symbols.lds"],
        linkopts = [
            # Symbol export/versioning is controlled by pjrt_symbols.lds
            # (its contents are not visible from this file).
            "-Wl,--version-script,$(location :pjrt_symbols.lds)",
            # Report unresolved symbols at link time rather than at load time.
            "-Wl,--no-undefined",
        ],
        linkshared = True,
        deps = ["//xla/pjrt/c:pjrt_c_api_gpu"],
    )
    for arch in cuda_wheels
]

# GPU wheels: one py_wheel per CUDA variant. Each bundles that variant's
# __init__.py (from init_file_<arch>) and xla_gpu_pjrt.so, and registers the
# package under the "xla_plugins" entry-point group.
[
    py_wheel(
        name = "xla_{}_pjrt".format(arch),
        author = "The OpenXLA Authors",
        classifiers = ["Development Status :: 1 - Planning"],
        distribution = "xla_{}_pjrt".format(arch),
        entry_points = {
            "xla_plugins": [
                # "<name> = <module>". The trailing newline is kept verbatim;
                # presumably it terminates the generated entry_points.txt.
                "xla_{0}_pjrt = xla_plugins.xla_{0}_pjrt\n".format(arch),
            ],
        },
        # `abi` is left at py_wheel's default.
        platform = "manylinux2014_x86_64",
        python_tag = "py3",
        # Presumably this BUILD file's package path; stripping it puts the
        # files at xla_plugins/... inside the wheel.
        strip_path_prefixes = ["build_tools/pjrt_wheels"],
        summary = "XLA PJRT Plugin",
        version = "0.0.0.dev0",
        deps = [
            ":xla_plugins/xla_{}_pjrt/xla_gpu_pjrt.so".format(arch),
            ":init_file_{}".format(arch),
        ],
    )
    for arch in cuda_wheels
]

# CPU-specific files: the PJRT plugin shared object for the CPU variant, named so
# it lands at xla_plugins/xla_cpu_pjrt/xla_cpu_pjrt.so for the CPU wheel.
[
    # TODO: Replace this with alias to .so found in xla/pjrt/c/BUILD
    cc_binary(
        name = "xla_plugins/xla_" + arch + "_pjrt/xla_cpu_pjrt.so",
        # The version script is a linker input, not a C++ dependency. Declaring
        # it here makes it available to the $(location) expansion in linkopts.
        additional_linker_inputs = [":pjrt_symbols.lds"],
        linkopts = [
            # Symbol export/versioning is controlled by pjrt_symbols.lds
            # (its contents are not visible from this file).
            "-Wl,--version-script,$(location :pjrt_symbols.lds)",
            # Report unresolved symbols at link time rather than at load time.
            "-Wl,--no-undefined",
        ],
        linkshared = True,
        deps = ["//xla/pjrt/c:pjrt_c_api_cpu"],
    )
    for arch in cpu_wheel
]

# CPU wheel: bundles the CPU variant's __init__.py (from init_file_cpu) and
# xla_cpu_pjrt.so, and registers the package under the "xla_plugins"
# entry-point group.
[
    py_wheel(
        name = "xla_{}_pjrt".format(arch),
        author = "The OpenXLA Authors",
        classifiers = ["Development Status :: 1 - Planning"],
        distribution = "xla_{}_pjrt".format(arch),
        entry_points = {
            "xla_plugins": [
                # "<name> = <module>". The trailing newline is kept verbatim;
                # presumably it terminates the generated entry_points.txt.
                "xla_{0}_pjrt = xla_plugins.xla_{0}_pjrt\n".format(arch),
            ],
        },
        # `abi` is left at py_wheel's default.
        platform = "manylinux2014_x86_64",
        python_tag = "py3",
        # Presumably this BUILD file's package path; stripping it puts the
        # files at xla_plugins/... inside the wheel.
        strip_path_prefixes = ["build_tools/pjrt_wheels"],
        summary = "XLA PJRT Plugin",
        version = "0.0.0.dev0",
        deps = [
            ":xla_plugins/xla_{}_pjrt/xla_cpu_pjrt.so".format(arch),
            ":init_file_{}".format(arch),
        ],
    )
    for arch in cpu_wheel
]
